package com.chatplus.application.aiprocessor.listener;

import cn.bugstack.openai.executor.parameter.*;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.extra.spring.SpringUtil;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.chatplus.application.aiprocessor.function.TriggerFunctionService;
import com.chatplus.application.aiprocessor.provider.TriggerFunctionServiceProvider;
import com.chatplus.application.aiprocessor.util.ChatWebSocketUtil;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.common.util.PlusJsonUtils;
import com.chatplus.application.domain.entity.functions.FunctionEntity;
import com.chatplus.application.domain.request.ws.ChatWebSocketRequest;
import com.chatplus.application.enumeration.MessageTypeEnum;
import com.chatplus.application.service.account.UserProductLogService;
import com.chatplus.application.service.functions.FunctionService;
import okhttp3.Response;
import okhttp3.internal.http2.StreamResetException;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * SSE listener for streamed OpenAI-style chat completions.
 * <p>
 * Forwards each non-empty streamed content delta to the request's WebSocket session, accumulates
 * the full reply, and collects any streamed {@code function_call} name and arguments. When the
 * stream closes it either saves the chat history or runs the requested function tool, then sends
 * {@code WS_END} and deducts the request's power.
 * <p>
 * The "delayed response" notice scheduled in {@code onOpen} runs on a separate scheduler thread,
 * so the state it reads ({@code isEnd}, {@code hasReply}) is {@code volatile}.
 */
public class CommonEventSourceListener extends EventSourceListener {
    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(CommonEventSourceListener.class);

    /** finish_reason values that mark the end of the answer (compared case-insensitively). */
    private static final List<String> STOP_FLAGS = List.of("stop", "function_call", "[DONE]");

    /** Seconds after opening before a "still waiting" notice is sent if no content has arrived. */
    private static final long DELAY_NOTICE_SECONDS = 30;

    // Full text reply accumulated from the streamed content deltas.
    private final StringBuilder reply = new StringBuilder();
    // Function-call arguments, concatenated across streamed chunks.
    private final StringBuilder arguments = new StringBuilder();
    // Function-call name, concatenated across streamed chunks.
    private final StringBuilder name = new StringBuilder();

    // Set once the stream has closed or failed; read by the delay-notice thread.
    private volatile boolean isEnd = false;
    // Set once any content has been forwarded; read by the delay-notice thread instead of
    // touching the (non-thread-safe) reply builder from another thread.
    private volatile boolean hasReply = false;
    // Scheduler running the delay notice; shut down as soon as the stream ends.
    private volatile ScheduledThreadPoolExecutor delayNoticeExecutor;

    private final String sessionId;
    private final ChatWebSocketRequest request;

    /**
     * Creates a listener bound to one chat request.
     *
     * @param request the chat request; its session id is used for all WebSocket replies
     */
    public CommonEventSourceListener(ChatWebSocketRequest request) {
        sessionId = request.getSessionId();
        this.request = request;
    }

    /**
     * Sends {@code WS_START} to the client and schedules a one-shot notice that fires after
     * {@link #DELAY_NOTICE_SECONDS} if no content has been received and the stream has not ended.
     */
    @Override
    public void onOpen(EventSource eventSource, @NotNull Response response) {
        LOGGER.message("连接成功").context("eventSource", eventSource.request()).context("response", response).info();
        ChatWebSocketUtil.replyStreamMessage(sessionId, MessageTypeEnum.WS_START.getValue());
        ScheduledThreadPoolExecutor scheduledExecutor = ThreadUtil.createScheduledExecutor(1);
        delayNoticeExecutor = scheduledExecutor;
        scheduledExecutor.schedule(() -> {
            try {
                if (!hasReply && !isEnd) {
                    ChatWebSocketUtil.replyStreamMessage(sessionId, "数据接收存在延迟，请耐心等待，如长时间无应答请联系管理员或者重新生成对话！");
                    LOGGER.message("连接获取数据延迟发通知").info();
                }
            } finally {
                // Always release the scheduler thread, even if sending the notice fails.
                scheduledExecutor.shutdownNow();
            }
        }, DELAY_NOTICE_SECONDS, TimeUnit.SECONDS);
        super.onOpen(eventSource, response);
    }

    /**
     * Handles one streamed chunk: collects function-call fragments and content deltas, records
     * the completion token count when a stop reason is seen, and forwards non-empty content to
     * the client.
     */
    @Override
    public void onEvent(@NotNull EventSource eventSource, String id, String type, @NotNull String data) {
        LOGGER.message("AI返回数据").context("data", data).debug();
        if ("[DONE]".equalsIgnoreCase(data)) {
            LOGGER.message("AI 应答完成").context("data", data).info();
            return;
        }
        CompletionResponse chatCompletionResponse = PlusJsonUtils.parseObject(data, CompletionResponse.class);
        if (chatCompletionResponse == null || chatCompletionResponse.getChoices() == null) {
            return;
        }
        StringBuilder content = new StringBuilder();
        for (ChatChoice chatChoice : chatCompletionResponse.getChoices()) {
            Message delta = chatChoice.getDelta();
            // A chunk may carry only a finish_reason and no delta.
            if (Objects.nonNull(delta)) {
                collectDelta(delta, content);
            }
            if (isStopReason(chatChoice.getFinishReason())) {
                Usage usage = chatCompletionResponse.getUsage();
                if (Objects.nonNull(usage)) {
                    request.setReplyToken(usage.getCompletionTokens());
                }
                LOGGER.message("AI 应答完成").info();
                break;
            }
        }
        if (content.length() > 0) {
            ChatWebSocketUtil.replyStreamMessage(sessionId, content.toString());
            reply.append(content);
            hasReply = true;
        }
    }

    /**
     * Appends a delta's function-call fragments to {@code name}/{@code arguments} and its text
     * content to {@code content}; null fragments are skipped rather than appended as "null".
     */
    private void collectDelta(Message delta, StringBuilder content) {
        FunctionCall functionCall = delta.getFunctionCall();
        if (Objects.nonNull(functionCall)) {
            if (StringUtils.isNotEmpty(functionCall.getName())) {
                name.append(functionCall.getName());
            }
            if (Objects.nonNull(functionCall.getArguments())) {
                arguments.append(functionCall.getArguments());
            }
        }
        if (Objects.nonNull(delta.getContent())) {
            content.append(delta.getContent());
        }
    }

    /** Returns true when {@code finishReason} is non-blank and matches one of {@link #STOP_FLAGS}. */
    private static boolean isStopReason(String finishReason) {
        return StringUtils.isNotBlank(finishReason)
                && STOP_FLAGS.stream().anyMatch(finishReason::equalsIgnoreCase);
    }

    /** Shuts down the delay-notice scheduler, if one was started; safe to call repeatedly. */
    private void stopDelayNotice() {
        ScheduledThreadPoolExecutor executor = delayNoticeExecutor;
        if (executor != null) {
            executor.shutdownNow();
        }
    }

    /**
     * Finishes a normally closed stream: saves the reply to chat history (or runs the requested
     * function tool when a function name was streamed), then always sends {@code WS_END} and
     * deducts the request's power.
     */
    @Override
    public void onClosed(@NotNull EventSource eventSource) {
        super.onClosed(eventSource);
        // Mark the stream finished up front so the delay notice cannot fire while a
        // (possibly slow) function tool runs, and release the scheduler thread immediately.
        isEnd = true;
        stopDelayNotice();
        try {
            request.setReply(reply.toString());
            if (StringUtils.isEmpty(name)) {
                request.setUseContext(true);
                ChatWebSocketUtil.saveChatHistory(request);
            } else {
                functionAction(name.toString(), arguments.toString());
            }
        } finally {
            ChatWebSocketUtil.replyStreamMessage(sessionId, MessageTypeEnum.WS_END.getValue());
            // Power is only deducted once the conversation has really finished.
            UserProductLogService userProductLogService = SpringUtil.getBean(UserProductLogService.class);
            userProductLogService.reducePower(request.getUserId(), request.getPower(), request.getPlatform(), request.getHistoryItemId());
        }
        LOGGER.message("关闭连接").info();
    }

    /**
     * Handles a failed stream.
     * <p>
     * Without a response the failure is only logged. A {@code StreamResetException} whose message
     * is "stream was reset: CANCEL" just ends the WebSocket stream; any other failure is logged
     * and reported to the client. Finally the listener is marked as ended and the event source
     * is cancelled.
     */
    @Override
    public void onFailure(@NotNull EventSource eventSource, Throwable t, Response response) {
        stopDelayNotice();
        if (Objects.isNull(response)) {
            // NOTE(review): the client receives no WS_END or error message in this path — confirm
            // the caller handles it.
            LOGGER.message("连接失败，未收到响应").exception(t).warn();
            isEnd = true;
            return;
        }
        if (t instanceof StreamResetException streamResetException &&
                "stream was reset: CANCEL".equals(streamResetException.getMessage())) {
            ChatWebSocketUtil.replyStreamMessage(sessionId, MessageTypeEnum.WS_END.getValue());
        } else {
            LOGGER.message("连接异常").context("response", response).exception(t).error();
            ChatWebSocketUtil.replyFullMessage(sessionId, "服务端发生异常，等待管理员修复!", true);
        }
        // NOTE(review): isEnd is only set in onClosed or at the end of this method, so this branch
        // looks unreachable — confirm the intended condition (perhaps !isEnd to keep a partial reply).
        if (isEnd) {
            request.setUseContext(true);
            ChatWebSocketUtil.saveChatHistory(request);
        }
        super.onFailure(eventSource, t, response);
        isEnd = true;
        eventSource.cancel();
    }

    /**
     * Looks up the function tool {@code name}, executes it with {@code arguments} and streams the
     * Markdown result back to the session.
     * <p>
     * On success the result becomes the reply and is saved to chat history with
     * {@code useContext=false}. If the function is not found, the result is empty, or execution
     * throws, an error message is streamed instead and nothing is saved.
     *
     * @param name      function name as streamed by the model
     * @param arguments raw function arguments as streamed by the model (presumably JSON — verify)
     */
    public void functionAction(String name, String arguments) {
        LOGGER.message("调用函数工具").context("name", name).context("arguments", arguments).info();
        FunctionService functionService = SpringUtil.getBean(FunctionService.class);
        FunctionEntity functionEntity = functionService.getOne(Wrappers.<FunctionEntity>lambdaQuery().eq(FunctionEntity::getName, name));
        if (Objects.isNull(functionEntity)) {
            LOGGER.message("函数不存在或者返回的参数错误").context("name", name).warn();
            ChatWebSocketUtil.replyStreamMessage(sessionId, "触发函数工具不存在或者返回的参数错误，请等待管理员修复再重试");
            return;
        }
        try {
            // Renders as "正在调用工具 `<label>` 作答 ..." followed by two line separators.
            String prompt = String.format("正在调用工具 `%s` 作答 ...%n%n", functionEntity.getLabel());
            ChatWebSocketUtil.replyStreamMessage(sessionId, prompt);
            TriggerFunctionServiceProvider triggerFunctionServiceProvider = SpringUtil.getBean(TriggerFunctionServiceProvider.class);
            TriggerFunctionService triggerFunctionService = triggerFunctionServiceProvider.getTriggerFunctionService(name);
            String body = triggerFunctionService.executeReturnMarkdown(functionEntity, arguments);
            if (StringUtils.isEmpty(body)) {
                ChatWebSocketUtil.replyStreamMessage(sessionId, "调用函数工具出错：返回结果为空");
                return;
            }
            ChatWebSocketUtil.replyStreamMessage(sessionId, body);
            request.setReply(body);
            request.setUseContext(false);
            ChatWebSocketUtil.saveChatHistory(request);
        } catch (Exception e) {
            ChatWebSocketUtil.replyStreamMessage(sessionId, "调用函数工具出错，请等待管理员修复再重试");
            LOGGER.message("调用函数工具出错")
                    .context("name", name)
                    .context("arguments", arguments)
                    .exception(e)
                    .error();
        }
    }
}
